package com.icesoft.base.manager.security.login.password;

import cn.hutool.crypto.CryptoException;
import com.icesoft.base.manager.entity.system.SysUserLoginRecord;
import com.icesoft.base.manager.security.config.LoginPropertySetting;
import com.icesoft.base.manager.security.config.SecurityConst;
import com.icesoft.base.manager.service.system.SysUserLoginRecordService;
import com.icesoft.core.web.suppose.safehttp.SafeRequestConst;
import com.icesoft.core.web.suppose.safehttp.service.CryptoKeyPairService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.InternalAuthenticationServiceException;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.web.util.WebUtils;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.Date;

/**
 * Form-login filter that decrypts the submitted username and password before delegating to
 * {@link UsernamePasswordAuthenticationFilter}.
 *
 * <p>On top of the standard username/password flow it:
 * <ul>
 *   <li>checks the {@code validCode} captcha parameter against the value stored in the session
 *       under {@link SecurityConst#SESSION_CAPTCHA}, when {@code loginPropertySetting.isEnableLoginCode()};</li>
 *   <li>rejects the attempt early when the account has reached the configured number of recent
 *       login failures and the latest failure is still inside the lock interval.</li>
 * </ul>
 * Credentials are decrypted via {@link CryptoKeyPairService#privateDecrypt} using the key id sent in
 * the {@link SafeRequestConst#HEAD_KEY_ID_PARAM_NAME} request header.
 */
@Slf4j
public class UsernamePasswordDecryptAuthenticationFilter extends UsernamePasswordAuthenticationFilter {

    /** Milliseconds per minute; a long so minute-based arithmetic is done in 64 bits. */
    private static final long MILLIS_PER_MINUTE = 60_000L;

    /** User-facing message used whenever a credential cannot be decrypted. */
    private static final String DECRYPT_FAILED_MESSAGE = "页面过期，请刷新重试";

    @Autowired
    private CryptoKeyPairService cryptoKeyPairService;
    @Autowired
    private LoginPropertySetting loginPropertySetting;
    @Autowired
    private SysUserLoginRecordService sysUserLoginRecordService;

    /** Handles only {@code POST} requests to {@link SecurityConst#LOGIN_PROCESS_API}. */
    public UsernamePasswordDecryptAuthenticationFilter() {
        this.setRequiresAuthenticationRequestMatcher(new AntPathRequestMatcher(SecurityConst.LOGIN_PROCESS_API, "POST"));
    }

    /**
     * Verifies the captcha (when enabled) and then runs the standard username/password authentication.
     *
     * @throws AuthenticationException if the captcha is missing or wrong, or if the parent authentication fails
     */
    @Override
    public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response) throws AuthenticationException {
        if (loginPropertySetting.isEnableLoginCode()) {
            verifyCaptcha(request);
        }
        return super.attemptAuthentication(request, response);
    }

    /**
     * Compares the {@code validCode} request parameter with the captcha stored in the session
     * (case-insensitively). The stored captcha is removed before comparing, so every captcha can be
     * tried only once.
     */
    private void verifyCaptcha(HttpServletRequest request) {
        String code = request.getParameter("validCode");
        if (code == null) {
            // NOTE(review): InternalAuthenticationServiceException normally signals a system failure;
            // the type is kept in case the configured failure handler depends on it — confirm.
            throw new InternalAuthenticationServiceException("请提供验证码参数");
        }
        String codeSession = (String) WebUtils.getSessionAttribute(request, SecurityConst.SESSION_CAPTCHA);
        if (codeSession == null) {
            throw new BadCredentialsException("请先获取验证码");
        }
        // Consume the captcha before comparing so a wrong guess cannot be retried with the same captcha.
        WebUtils.setSessionAttribute(request, SecurityConst.SESSION_CAPTCHA, null);
        if (!codeSession.equalsIgnoreCase(code)) {
            throw new InternalAuthenticationServiceException("验证码错误");
        }
    }

    /**
     * Returns the decrypted password, or {@code null} when the request carries no password.
     *
     * @throws AuthenticationServiceException if the key-id header is missing or decryption fails
     */
    @Override
    protected String obtainPassword(HttpServletRequest request) {
        return decrypt(request, super.obtainPassword(request), "password");
    }

    /**
     * Returns the decrypted, trimmed username, or {@code null} when the request carries no username.
     * Also enforces the login-failure lockout for that username.
     *
     * @throws AuthenticationServiceException if the key-id header is missing, decryption fails, or the
     *                                        account is temporarily locked because of repeated failures
     */
    @Override
    protected String obtainUsername(HttpServletRequest request) {
        String username = decrypt(request, super.obtainUsername(request), "username");
        if (username != null) {
            // The parent filter trims the username after this method returns, so the account that is
            // authenticated is the trimmed one. Trim here too so the lockout counter is queried for that
            // same name; otherwise " admin" would bypass the counter kept for "admin" (presumably failures
            // are recorded under the trimmed name — confirm).
            username = username.trim();
            checkLoginLockout(username);
        }
        return username;
    }

    /**
     * Decrypts one credential field with the private key referenced by the request's key-id header.
     *
     * @param cipherText the raw request value; {@code null} is passed through unchanged
     * @param fieldName  field name used only for logging
     * @throws AuthenticationServiceException if the key-id header is missing (specific message) or
     *                                        decryption fails (generic "page expired" message, cause kept)
     */
    private String decrypt(HttpServletRequest request, String cipherText, String fieldName) {
        if (cipherText == null) {
            return null;
        }
        // Resolved outside the try block so its specific "missing header" error is not rewritten below.
        String keyId = obtainKeyId(request);
        try {
            return cryptoKeyPairService.privateDecrypt(keyId, cipherText);
        } catch (RuntimeException e) {
            // Never log the submitted credential, even in encrypted form.
            log.error("账号密码解码错误: field={}", fieldName, e);
            throw new AuthenticationServiceException(DECRYPT_FAILED_MESSAGE, e);
        }
    }

    /**
     * Rejects the attempt when {@code username} has at least {@code maxLimitCount} failures within the
     * last {@code geMaxLimitCountTimeMinute()} minutes and its most recent failure happened less than
     * {@code geLimitIntervalMinute()} minutes ago. A non-positive {@code maxLimitCount} disables the check.
     */
    private void checkLoginLockout(String username) {
        int maxLimitCount = loginPropertySetting.getMaxLimitCount();
        if (maxLimitCount <= 0) {
            return;
        }
        long now = System.currentTimeMillis();
        long windowStart = now - loginPropertySetting.geMaxLimitCountTimeMinute() * MILLIS_PER_MINUTE;
        int errorCount = sysUserLoginRecordService.errorCountByUsernameAndGeCreateTime(username, new Date(windowStart));
        if (errorCount < maxLimitCount) {
            return;
        }
        SysUserLoginRecord latestFailure = sysUserLoginRecordService.loadLatestLoginFailureRecord(username);
        if (latestFailure == null || latestFailure.getLoginTime() == null) {
            // No usable timestamp for the latest failure; nothing to measure the lock interval from.
            return;
        }
        long sinceLatestFailure = now - latestFailure.getLoginTime().getTime();
        if (sinceLatestFailure < loginPropertySetting.geLimitIntervalMinute() * MILLIS_PER_MINUTE) {
            throw new AuthenticationServiceException("账号登录错误次数过多，请稍后再试");
        }
    }

    /**
     * Reads the key id identifying the key pair used to encrypt the credentials.
     *
     * @throws AuthenticationServiceException if the {@link SafeRequestConst#HEAD_KEY_ID_PARAM_NAME} header is absent
     */
    private String obtainKeyId(HttpServletRequest request) {
        String keyId = request.getHeader(SafeRequestConst.HEAD_KEY_ID_PARAM_NAME);
        if (keyId == null) {
            throw new AuthenticationServiceException("缺少请求头：" + SafeRequestConst.HEAD_KEY_ID_PARAM_NAME);
        }
        return keyId;
    }
}
